{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analyzing out-of-this-world data, Part 2\n",
    "Using data collected from the Open Exoplanet Catalogue database: https://github.com/OpenExoplanetCatalogue/open_exoplanet_catalogue/\n",
    "\n",
    "## Data License\n",
    "Copyright (C) 2012 Hanno Rein\n",
    "\n",
    "Permission is hereby granted, free of charge, to any person obtaining a copy of this database and associated scripts (the \"Database\"), to deal in the Database without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Database, and to permit persons to whom the Database is furnished to do so, subject to the following conditions:\n",
    "\n",
    "The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Database. A reference to the Database shall be included in all scientific publications that make use of the Database.\n",
    "\n",
    "THE DATABASE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE DATABASE OR THE USE OR OTHER DEALINGS IN THE DATABASE."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "planets = pd.read_csv('data/planets.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we completed our EDA in the [`planets_ml.ipynb` notebook from the last chapter](https://github.com/stefmolin/Hands-On-Data-Analysis-with-Pandas/blob/master/ch_09/planets_ml.ipynb), we will just look at the first five rows to refresh our memory of the data rather than repeat the EDA here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>period</th>\n",
       "      <th>name</th>\n",
       "      <th>eccentricity</th>\n",
       "      <th>description</th>\n",
       "      <th>discoverymethod</th>\n",
       "      <th>periastrontime</th>\n",
       "      <th>lastupdate</th>\n",
       "      <th>semimajoraxis</th>\n",
       "      <th>mass</th>\n",
       "      <th>periastron</th>\n",
       "      <th>list</th>\n",
       "      <th>discoveryyear</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>326.03</td>\n",
       "      <td>11 Com b</td>\n",
       "      <td>0.231</td>\n",
       "      <td>11 Com b is a brown dwarf-mass companion to th...</td>\n",
       "      <td>RV</td>\n",
       "      <td>2452899.60</td>\n",
       "      <td>15/09/20</td>\n",
       "      <td>1.290</td>\n",
       "      <td>19.400</td>\n",
       "      <td>94.800</td>\n",
       "      <td>Confirmed planets</td>\n",
       "      <td>2008.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>516.22</td>\n",
       "      <td>11 UMi b</td>\n",
       "      <td>0.080</td>\n",
       "      <td>11 Ursae Minoris is a star located in the cons...</td>\n",
       "      <td>RV</td>\n",
       "      <td>2452861.04</td>\n",
       "      <td>15/09/20</td>\n",
       "      <td>1.540</td>\n",
       "      <td>11.200</td>\n",
       "      <td>117.630</td>\n",
       "      <td>Confirmed planets</td>\n",
       "      <td>2009.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>185.84</td>\n",
       "      <td>14 And b</td>\n",
       "      <td>0.000</td>\n",
       "      <td>14 Andromedae is an evolved star in the conste...</td>\n",
       "      <td>RV</td>\n",
       "      <td>2452861.40</td>\n",
       "      <td>15/09/20</td>\n",
       "      <td>0.830</td>\n",
       "      <td>4.800</td>\n",
       "      <td>0.000</td>\n",
       "      <td>Confirmed planets</td>\n",
       "      <td>2008.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1766.00</td>\n",
       "      <td>14 Her b</td>\n",
       "      <td>0.359</td>\n",
       "      <td>The star 14 Herculis is only 59 light years aw...</td>\n",
       "      <td>RV</td>\n",
       "      <td>NaN</td>\n",
       "      <td>15/09/21</td>\n",
       "      <td>2.864</td>\n",
       "      <td>4.975</td>\n",
       "      <td>22.230</td>\n",
       "      <td>Confirmed planets</td>\n",
       "      <td>2002.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>9886.00</td>\n",
       "      <td>14 Her c</td>\n",
       "      <td>0.184</td>\n",
       "      <td>14 Her c is the second companion in the system...</td>\n",
       "      <td>RV</td>\n",
       "      <td>NaN</td>\n",
       "      <td>15/09/21</td>\n",
       "      <td>9.037</td>\n",
       "      <td>7.679</td>\n",
       "      <td>189.076</td>\n",
       "      <td>Controversial</td>\n",
       "      <td>2006.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    period      name  eccentricity  \\\n",
       "0   326.03  11 Com b         0.231   \n",
       "1   516.22  11 UMi b         0.080   \n",
       "2   185.84  14 And b         0.000   \n",
       "3  1766.00  14 Her b         0.359   \n",
       "4  9886.00  14 Her c         0.184   \n",
       "\n",
       "                                         description discoverymethod  \\\n",
       "0  11 Com b is a brown dwarf-mass companion to th...              RV   \n",
       "1  11 Ursae Minoris is a star located in the cons...              RV   \n",
       "2  14 Andromedae is an evolved star in the conste...              RV   \n",
       "3  The star 14 Herculis is only 59 light years aw...              RV   \n",
       "4  14 Her c is the second companion in the system...              RV   \n",
       "\n",
       "   periastrontime lastupdate  semimajoraxis    mass  periastron  \\\n",
       "0      2452899.60   15/09/20          1.290  19.400      94.800   \n",
       "1      2452861.04   15/09/20          1.540  11.200     117.630   \n",
       "2      2452861.40   15/09/20          0.830   4.800       0.000   \n",
       "3             NaN   15/09/21          2.864   4.975      22.230   \n",
       "4             NaN   15/09/21          9.037   7.679     189.076   \n",
       "\n",
       "                list  discoveryyear  \n",
       "0  Confirmed planets         2008.0  \n",
       "1  Confirmed planets         2009.0  \n",
       "2  Confirmed planets         2008.0  \n",
       "3  Confirmed planets         2002.0  \n",
       "4      Controversial         2006.0  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "planets.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predicting Length of Year (Period)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "data = planets[\n",
    "    ['semimajoraxis', 'period', 'mass', 'eccentricity']\n",
    "].dropna()\n",
    "planets_X = data[['semimajoraxis', 'mass', 'eccentricity']]\n",
    "planets_y = data.period\n",
    "\n",
    "pl_X_train, pl_X_test, pl_y_train, pl_y_test = train_test_split(\n",
    "    planets_X, planets_y, test_size=0.25, random_state=0\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Hyperparameter Tuning with GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import make_scorer, mean_squared_error\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "model_pipeline = Pipeline([\n",
    "    ('scale', StandardScaler()),\n",
    "    ('lr', LinearRegression())\n",
    "])\n",
    "\n",
    "search_space = {\n",
    "    'scale__with_mean' : [True, False], 'scale__with_std' : [True, False],\n",
    "    'lr__fit_intercept': [True, False], 'lr__normalize' : [True, False]\n",
    "}\n",
    "grid = GridSearchCV(\n",
    "    model_pipeline, search_space, cv=5,\n",
    "    scoring={\n",
    "        'r_squared': 'r2',\n",
    "        'mse' : 'neg_mean_squared_error',\n",
    "        'mae' : 'neg_mean_absolute_error',\n",
    "        # RMSE is an error (lower is better), so have the scorer negate it\n",
    "        'rmse' : make_scorer(\n",
    "            lambda y_true, y_pred: np.sqrt(mean_squared_error(y_true, y_pred)),\n",
    "            greater_is_better=False\n",
    "        )\n",
    "    }, refit='mae', iid=False\n",
    ").fit(pl_X_train, pl_y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because we set `refit='mae'`, `best_score_` is the mean cross-validated negative MAE of the best hyperparameter combination in the search space, and `best_params_` holds that combination:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Best parameters (CV score=-1605.54):\n",
      "{'lr__fit_intercept': False, 'lr__normalize': True, 'scale__with_mean': False, 'scale__with_std': True}\n"
     ]
    }
   ],
   "source": [
    "print('Best parameters (CV score=%.2f):\\n%s' % (\n",
    "    grid.best_score_, grid.best_params_\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the test set, the tuned pipeline's MAE is over 150 Earth days lower than that of [chapter 9's model](https://github.com/stefmolin/Hands-On-Data-Analysis-with-Pandas/blob/master/ch_09/planets_ml.ipynb):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1235.4924651855556"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.metrics import mean_absolute_error\n",
    "mean_absolute_error(pl_y_test, grid.predict(pl_X_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Can a decision tree tell us which features are important?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('semimajoraxis', 0.9909470467126007),\n",
       " ('mass', 0.008457963837110774),\n",
       " ('eccentricity', 0.0005949894502886271)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeRegressor\n",
    "\n",
    "dt = DecisionTreeRegressor(random_state=0).fit(pl_X_train, pl_y_train)\n",
    "# pair each feature name with its importance in the fitted tree\n",
    "[(col, importance) for col, importance in zip(\n",
    "    pl_X_train.columns, dt.feature_importances_\n",
    ")]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can visualize the decisions the tree is making:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\r\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\r\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\r\n",
       "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\r\n",
       " -->\r\n",
       "<!-- Title: Tree Pages: 1 -->\r\n",
       "<svg width=\"1233pt\" height=\"477pt\"\r\n",
       " viewBox=\"0.00 0.00 1233.00 477.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\r\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 473)\">\r\n",
       "<title>Tree</title>\r\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-473 1229,-473 1229,4 -4,4\"/>\r\n",
       "<!-- 0 -->\r\n",
       "<g id=\"node1\" class=\"node\"><title>0</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"881,-469 724,-469 724,-401 881,-401 881,-469\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"802.5\" y=\"-453.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">semimajoraxis &lt;= 34.743</text>\r\n",
       "<text text-anchor=\"middle\" x=\"802.5\" y=\"-438.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 102060175.205</text>\r\n",
       "<text text-anchor=\"middle\" x=\"802.5\" y=\"-423.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 783</text>\r\n",
       "<text text-anchor=\"middle\" x=\"802.5\" y=\"-408.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 1452.442</text>\r\n",
       "</g>\r\n",
       "<!-- 1 -->\r\n",
       "<g id=\"node2\" class=\"node\"><title>1</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"789,-365 652,-365 652,-297 789,-297 789,-365\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">semimajoraxis &lt;= 7.6</text>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 6017695.927</text>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 780</text>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 870.918</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;1 -->\r\n",
       "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M775.877,-400.884C768.788,-392.065 761.042,-382.43 753.65,-373.235\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"756.264,-370.9 747.27,-365.299 750.808,-375.286 756.264,-370.9\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"744.569\" y=\"-386.453\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">True</text>\r\n",
       "</g>\r\n",
       "<!-- 16 -->\r\n",
       "<g id=\"node17\" class=\"node\"><title>16</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"957.5,-365 811.5,-365 811.5,-297 957.5,-297 957.5,-365\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"884.5\" y=\"-349.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mass &lt;= 2.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"884.5\" y=\"-334.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 2124859022.429</text>\r\n",
       "<text text-anchor=\"middle\" x=\"884.5\" y=\"-319.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"884.5\" y=\"-304.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 152648.743</text>\r\n",
       "</g>\r\n",
       "<!-- 0&#45;&gt;16 -->\r\n",
       "<g id=\"edge16\" class=\"edge\"><title>0&#45;&gt;16</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M829.123,-400.884C836.212,-392.065 843.958,-382.43 851.35,-373.235\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"854.192,-375.286 857.73,-365.299 848.736,-370.9 854.192,-375.286\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"860.431\" y=\"-386.453\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">False</text>\r\n",
       "</g>\r\n",
       "<!-- 2 -->\r\n",
       "<g id=\"node3\" class=\"node\"><title>2</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"470.5,-261 320.5,-261 320.5,-193 470.5,-193 470.5,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">semimajoraxis &lt;= 2.889</text>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 1413673.905</text>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 767</text>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 640.004</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;2 -->\r\n",
       "<g id=\"edge2\" class=\"edge\"><title>1&#45;&gt;2</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M651.973,-308.493C602.007,-292.811 533.979,-271.461 480.541,-254.69\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"481.329,-251.269 470.74,-251.614 479.233,-257.948 481.329,-251.269\"/>\r\n",
       "</g>\r\n",
       "<!-- 9 -->\r\n",
       "<g id=\"node10\" class=\"node\"><title>9</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"799,-261 642,-261 642,-193 799,-193 799,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">semimajoraxis &lt;= 15.975</text>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 88898325.341</text>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 13</text>\r\n",
       "<text text-anchor=\"middle\" x=\"720.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 14494.821</text>\r\n",
       "</g>\r\n",
       "<!-- 1&#45;&gt;9 -->\r\n",
       "<g id=\"edge9\" class=\"edge\"><title>1&#45;&gt;9</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M720.5,-296.884C720.5,-288.778 720.5,-279.982 720.5,-271.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"724,-271.299 720.5,-261.299 717,-271.299 724,-271.299\"/>\r\n",
       "</g>\r\n",
       "<!-- 3 -->\r\n",
       "<g id=\"node4\" class=\"node\"><title>3</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"251.5,-157 101.5,-157 101.5,-89 251.5,-89 251.5,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">semimajoraxis &lt;= 1.455</text>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 182863.046</text>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 682</text>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 283.248</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;3 -->\r\n",
       "<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M324.397,-192.884C302.868,-182.856 279.072,-171.773 256.975,-161.482\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"258.164,-158.174 247.621,-157.125 255.208,-164.52 258.164,-158.174\"/>\r\n",
       "</g>\r\n",
       "<!-- 6 -->\r\n",
       "<g id=\"node7\" class=\"node\"><title>6</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"467,-157 324,-157 324,-89 467,-89 467,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">semimajoraxis &lt;= 4.47</text>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 2074320.063</text>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 85</text>\r\n",
       "<text text-anchor=\"middle\" x=\"395.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 3502.452</text>\r\n",
       "</g>\r\n",
       "<!-- 2&#45;&gt;6 -->\r\n",
       "<g id=\"edge6\" class=\"edge\"><title>2&#45;&gt;6</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M395.5,-192.884C395.5,-184.778 395.5,-175.982 395.5,-167.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"399,-167.299 395.5,-157.299 392,-167.299 399,-167.299\"/>\r\n",
       "</g>\r\n",
       "<!-- 4 -->\r\n",
       "<g id=\"node5\" class=\"node\"><title>4</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"99,-53 0,-53 0,-0 99,-0 99,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"49.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 29434.8</text>\r\n",
       "<text text-anchor=\"middle\" x=\"49.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 560</text>\r\n",
       "<text text-anchor=\"middle\" x=\"49.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 111.93</text>\r\n",
       "</g>\r\n",
       "<!-- 3&#45;&gt;4 -->\r\n",
       "<g id=\"edge4\" class=\"edge\"><title>3&#45;&gt;4</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M132.082,-88.9485C119.173,-79.3431 105.104,-68.8747 92.297,-59.345\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"94.2,-56.3984 84.088,-53.2367 90.0213,-62.0143 94.2,-56.3984\"/>\r\n",
       "</g>\r\n",
       "<!-- 5 -->\r\n",
       "<g id=\"node6\" class=\"node\"><title>5</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"236,-53 117,-53 117,-0 236,-0 236,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 134017.827</text>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 122</text>\r\n",
       "<text text-anchor=\"middle\" x=\"176.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 1069.623</text>\r\n",
       "</g>\r\n",
       "<!-- 3&#45;&gt;5 -->\r\n",
       "<g id=\"edge5\" class=\"edge\"><title>3&#45;&gt;5</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M176.5,-88.9485C176.5,-80.7153 176.5,-71.848 176.5,-63.4814\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"180,-63.2367 176.5,-53.2367 173,-63.2367 180,-63.2367\"/>\r\n",
       "</g>\r\n",
       "<!-- 7 -->\r\n",
       "<g id=\"node8\" class=\"node\"><title>7</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"367,-53 254,-53 254,-0 367,-0 367,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"310.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 877802.58</text>\r\n",
       "<text text-anchor=\"middle\" x=\"310.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 45</text>\r\n",
       "<text text-anchor=\"middle\" x=\"310.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 2561.063</text>\r\n",
       "</g>\r\n",
       "<!-- 6&#45;&gt;7 -->\r\n",
       "<g id=\"edge7\" class=\"edge\"><title>6&#45;&gt;7</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M365.771,-88.9485C357.543,-79.8005 348.61,-69.8697 340.374,-60.7126\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"342.939,-58.331 333.649,-53.2367 337.735,-63.0123 342.939,-58.331\"/>\r\n",
       "</g>\r\n",
       "<!-- 8 -->\r\n",
       "<g id=\"node9\" class=\"node\"><title>8</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"511.5,-53 385.5,-53 385.5,-0 511.5,-0 511.5,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"448.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 1301799.786</text>\r\n",
       "<text text-anchor=\"middle\" x=\"448.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 40</text>\r\n",
       "<text text-anchor=\"middle\" x=\"448.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 4561.514</text>\r\n",
       "</g>\r\n",
       "<!-- 6&#45;&gt;8 -->\r\n",
       "<g id=\"edge8\" class=\"edge\"><title>6&#45;&gt;8</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M414.037,-88.9485C418.911,-80.2579 424.181,-70.8608 429.099,-62.0917\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"432.227,-63.6707 434.066,-53.2367 426.121,-60.2465 432.227,-63.6707\"/>\r\n",
       "</g>\r\n",
       "<!-- 10 -->\r\n",
       "<g id=\"node11\" class=\"node\"><title>10</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"781.5,-157 655.5,-157 655.5,-89 781.5,-89 781.5,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"718.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mass &lt;= 4.025</text>\r\n",
       "<text text-anchor=\"middle\" x=\"718.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 8969158.624</text>\r\n",
       "<text text-anchor=\"middle\" x=\"718.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 9</text>\r\n",
       "<text text-anchor=\"middle\" x=\"718.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 10127.963</text>\r\n",
       "</g>\r\n",
       "<!-- 9&#45;&gt;10 -->\r\n",
       "<g id=\"edge10\" class=\"edge\"><title>9&#45;&gt;10</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M719.851,-192.884C719.692,-184.778 719.519,-175.982 719.352,-167.472\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"722.848,-167.229 719.153,-157.299 715.85,-167.366 722.848,-167.229\"/>\r\n",
       "</g>\r\n",
       "<!-- 13 -->\r\n",
       "<g id=\"node14\" class=\"node\"><title>13</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"957,-157 800,-157 800,-89 957,-89 957,-157\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"878.5\" y=\"-141.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">semimajoraxis &lt;= 28.485</text>\r\n",
       "<text text-anchor=\"middle\" x=\"878.5\" y=\"-126.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 129293642.188</text>\r\n",
       "<text text-anchor=\"middle\" x=\"878.5\" y=\"-111.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 4</text>\r\n",
       "<text text-anchor=\"middle\" x=\"878.5\" y=\"-96.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 24320.25</text>\r\n",
       "</g>\r\n",
       "<!-- 9&#45;&gt;13 -->\r\n",
       "<g id=\"edge13\" class=\"edge\"><title>9&#45;&gt;13</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M771.798,-192.884C786.776,-183.214 803.274,-172.563 818.728,-162.587\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"820.686,-165.489 827.189,-157.125 816.889,-159.608 820.686,-165.489\"/>\r\n",
       "</g>\r\n",
       "<!-- 11 -->\r\n",
       "<g id=\"node12\" class=\"node\"><title>11</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"649,-53 530,-53 530,-0 649,-0 649,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"589.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 3283222.18</text>\r\n",
       "<text text-anchor=\"middle\" x=\"589.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"589.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 13252.557</text>\r\n",
       "</g>\r\n",
       "<!-- 10&#45;&gt;11 -->\r\n",
       "<g id=\"edge11\" class=\"edge\"><title>10&#45;&gt;11</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M673.382,-88.9485C660.27,-79.3431 645.98,-68.8747 632.971,-59.345\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"634.768,-56.3228 624.633,-53.2367 630.631,-61.9698 634.768,-56.3228\"/>\r\n",
       "</g>\r\n",
       "<!-- 12 -->\r\n",
       "<g id=\"node13\" class=\"node\"><title>12</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"793.5,-53 667.5,-53 667.5,-0 793.5,-0 793.5,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"730.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 4489814.222</text>\r\n",
       "<text text-anchor=\"middle\" x=\"730.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 6</text>\r\n",
       "<text text-anchor=\"middle\" x=\"730.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 8565.667</text>\r\n",
       "</g>\r\n",
       "<!-- 10&#45;&gt;12 -->\r\n",
       "<g id=\"edge12\" class=\"edge\"><title>10&#45;&gt;12</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M722.697,-88.9485C723.754,-80.6238 724.894,-71.6509 725.966,-63.2027\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"729.444,-63.598 727.232,-53.2367 722.5,-62.7161 729.444,-63.598\"/>\r\n",
       "</g>\r\n",
       "<!-- 14 -->\r\n",
       "<g id=\"node15\" class=\"node\"><title>14</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"945,-53 812,-53 812,-0 945,-0 945,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"878.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 81249496.222</text>\r\n",
       "<text text-anchor=\"middle\" x=\"878.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 3</text>\r\n",
       "<text text-anchor=\"middle\" x=\"878.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 29093.667</text>\r\n",
       "</g>\r\n",
       "<!-- 13&#45;&gt;14 -->\r\n",
       "<g id=\"edge14\" class=\"edge\"><title>13&#45;&gt;14</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M878.5,-88.9485C878.5,-80.7153 878.5,-71.848 878.5,-63.4814\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"882,-63.2367 878.5,-53.2367 875,-63.2367 882,-63.2367\"/>\r\n",
       "</g>\r\n",
       "<!-- 15 -->\r\n",
       "<g id=\"node16\" class=\"node\"><title>15</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"1068,-53 963,-53 963,-0 1068,-0 1068,-53\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"1015.5\" y=\"-37.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1015.5\" y=\"-22.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1015.5\" y=\"-7.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 10000.0</text>\r\n",
       "</g>\r\n",
       "<!-- 13&#45;&gt;15 -->\r\n",
       "<g id=\"edge15\" class=\"edge\"><title>13&#45;&gt;15</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M926.416,-88.9485C940.474,-79.2516 955.807,-68.6752 969.728,-59.073\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"971.944,-61.7959 978.189,-53.2367 967.97,-56.0337 971.944,-61.7959\"/>\r\n",
       "</g>\r\n",
       "<!-- 17 -->\r\n",
       "<g id=\"node18\" class=\"node\"><title>17</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"940.5,-253.5 828.5,-253.5 828.5,-200.5 940.5,-200.5 940.5,-253.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"884.5\" y=\"-238.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"884.5\" y=\"-223.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"884.5\" y=\"-208.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 90553.02</text>\r\n",
       "</g>\r\n",
       "<!-- 16&#45;&gt;17 -->\r\n",
       "<g id=\"edge17\" class=\"edge\"><title>16&#45;&gt;17</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M884.5,-296.884C884.5,-286.326 884.5,-274.597 884.5,-263.854\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"888,-263.52 884.5,-253.52 881,-263.52 888,-263.52\"/>\r\n",
       "</g>\r\n",
       "<!-- 18 -->\r\n",
       "<g id=\"node19\" class=\"node\"><title>18</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"1101.5,-261 961.5,-261 961.5,-193 1101.5,-193 1101.5,-261\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"1031.5\" y=\"-245.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">eccentricity &lt;= 0.175</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1031.5\" y=\"-230.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 295379391.426</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1031.5\" y=\"-215.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 2</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1031.5\" y=\"-200.8\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 183696.605</text>\r\n",
       "</g>\r\n",
       "<!-- 16&#45;&gt;18 -->\r\n",
       "<g id=\"edge18\" class=\"edge\"><title>16&#45;&gt;18</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M932.227,-296.884C946.032,-287.304 961.227,-276.761 975.489,-266.864\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"977.541,-269.701 983.761,-261.125 973.55,-263.95 977.541,-269.701\"/>\r\n",
       "</g>\r\n",
       "<!-- 19 -->\r\n",
       "<g id=\"node20\" class=\"node\"><title>19</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"1087.5,-149.5 975.5,-149.5 975.5,-96.5 1087.5,-96.5 1087.5,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"1031.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1031.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1031.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 166510.0</text>\r\n",
       "</g>\r\n",
       "<!-- 18&#45;&gt;19 -->\r\n",
       "<g id=\"edge19\" class=\"edge\"><title>18&#45;&gt;19</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1031.5,-192.884C1031.5,-182.326 1031.5,-170.597 1031.5,-159.854\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1035,-159.52 1031.5,-149.52 1028,-159.52 1035,-159.52\"/>\r\n",
       "</g>\r\n",
       "<!-- 20 -->\r\n",
       "<g id=\"node21\" class=\"node\"><title>20</title>\r\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"1225,-149.5 1106,-149.5 1106,-96.5 1225,-96.5 1225,-149.5\"/>\r\n",
       "<text text-anchor=\"middle\" x=\"1165.5\" y=\"-134.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">mse = 0.0</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1165.5\" y=\"-119.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">samples = 1</text>\r\n",
       "<text text-anchor=\"middle\" x=\"1165.5\" y=\"-104.3\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">value = 200883.21</text>\r\n",
       "</g>\r\n",
       "<!-- 18&#45;&gt;20 -->\r\n",
       "<g id=\"edge20\" class=\"edge\"><title>18&#45;&gt;20</title>\r\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M1075.01,-192.884C1090.61,-181.006 1108.16,-167.646 1123.62,-155.876\"/>\r\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"1126.14,-158.362 1131.97,-149.52 1121.9,-152.792 1126.14,-158.362\"/>\r\n",
       "</g>\r\n",
       "</g>\r\n",
       "</svg>\r\n"
      ],
      "text/plain": [
       "<graphviz.files.Source at 0x1140d110>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.tree import export_graphviz\n",
    "import graphviz\n",
    "\n",
    "graphviz.Source(export_graphviz(\n",
    "    DecisionTreeRegressor(\n",
    "        max_depth=4, random_state=0\n",
    "    ).fit(pl_X_train, pl_y_train),\n",
    "    feature_names=pl_X_train.columns\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regularization\n",
    "Ridge (L2) and LASSO (L1) are terms we can add to regression to penalize large coefficients. L2 uses squared coefficients and L1 uses the absolute value. LASSO tends to drive coeffiecients to zero and is therefore used for feature selection. We can combine the two in an elastic net. These help to reduce overfitting. Check the documentation for parameters that can be tuned:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ridge: 0.9302\n",
      "Lasso: 0.9298\n",
      "ElasticNet: 0.9375\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import Ridge, Lasso, ElasticNet\n",
    "\n",
    "ridge, lasso, elastic = Ridge(), Lasso(), ElasticNet()\n",
    "\n",
    "for model in [ridge, lasso, elastic]:\n",
    "    model.fit(pl_X_train, pl_y_train)\n",
    "    print(\n",
    "        f'{model.__class__.__name__}: '\n",
    "        f'{model.score(pl_X_test, pl_y_test):.4}'\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
